package com.auditlog.datasource.execute;

import com.auditlog.datasource.StatementProxy;
import com.auditlog.datasource.context.ConnectionContext;
import com.auditlog.datasource.db.SqlType;
import com.auditlog.sql.factory.SqlRunnerFactoryRegistry;
import com.auditlog.datasource.struct.SqlMeta;
import com.auditlog.sql.runner.SqlRunner;
import com.auditlog.sql.runner.SqlRunnerChain;
import com.auditlog.datasource.context.ContextHolder;
import com.auditlog.datasource.struct.TupleTwo;
import com.auditlog.exception.NotSupportException;
import lombok.extern.slf4j.Slf4j;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

/**
 * Executor that pairs every parsed statement of the current {@link SqlMeta} with a
 * {@link SqlRunner} and drives the resulting {@link SqlRunnerChain} around the actual execution.
 *
 * <p>{@link #beforeExecute()} builds the chain and invokes its pre-execution hook;
 * {@link #afterExecute()} invokes the post-execution hook and hands the chain's results to the
 * execute context held by {@link ContextHolder}.
 *
 * @param <T> result type produced by the statement callback
 * @param <S> concrete JDBC {@link Statement} type being proxied
 */
@Slf4j
public class MultiExecutor<T, S extends Statement> extends BaseExecutor<T, S> {

    /** Chain assigned by {@link #beforeExecute()}; remains {@code null} until that assignment happens. */
    private SqlRunnerChain sqlRunnerChain;

    /**
     * Creates the executor; all arguments are forwarded unchanged to the base executor.
     *
     * @param statementProxy    proxy around the JDBC statement
     * @param statementCallback callback passed through to the base executor
     * @param sqlMeta           metadata supplying the parsed statements and their SQL types
     */
    public MultiExecutor(StatementProxy<S> statementProxy, StatementCallback<T, S> statementCallback, SqlMeta sqlMeta) {
        super(statementProxy, statementCallback, sqlMeta);
    }

    /**
     * Builds the runner chain and runs its pre-execution hook.
     *
     * <p>Does nothing when the current connection context is already flagged as errored.
     * Any exception raised by the chain's hook is passed to the inherited
     * {@code handleException}; exceptions raised while validating or building the runners are
     * not intercepted here.
     *
     * @throws NotSupportException when the connection context reports a multi-statement SQL and
     *                             one of its types is {@link SqlType#INSERT_DUPLICATE}
     * @throws SQLException        as declared by the overridden method
     */
    @Override
    public void beforeExecute() throws SQLException {
        ConnectionContext context = ContextHolder.getConnectionContext();
        if (context.isError()) {
            return;
        }

        // Guard against one SQL string that contains several statements when those statements
        // include an "insert ... on duplicate key".
        List<SqlType> types = this.getSqlMeta().getSqlTypes();
        boolean multiStatement = context.isPrepareMultiSql() || context.isNonPrepareMultiSql();
        if (multiStatement && types.contains(SqlType.INSERT_DUPLICATE)) {
            throw new NotSupportException("一个sql中包含多个sql语句，且包含的sql含有insert操作");
        }

        List<TupleTwo<net.sf.jsqlparser.statement.Statement, SqlRunner>> runnerPairs = pairStatementsWithRunners(types);
        this.sqlRunnerChain = new SqlRunnerChain(this.getStatementProxy(), runnerPairs);

        try {
            this.sqlRunnerChain.beforeExecute();
        } catch (Exception ex) {
            handleException(context, ex);
        }
    }

    /**
     * Runs the chain's post-execution hook and publishes the chain's results.
     *
     * <p>Nothing happens when no chain was assigned by {@link #beforeExecute()}. The hook itself
     * is skipped when the connection context is flagged as errored, but the results reported by
     * the chain are still added to the execute context afterwards.
     *
     * @throws SQLException as declared by the overridden method
     */
    @Override
    public void afterExecute() throws SQLException {
        ConnectionContext context = ContextHolder.getConnectionContext();
        if (this.sqlRunnerChain == null) {
            return;
        }

        boolean healthy = !context.isError();
        if (healthy) {
            try {
                this.sqlRunnerChain.afterExecute();
            } catch (Exception ex) {
                handleException(context, ex);
            }
        }

        // The execute context is looked up per element, exactly once for every reported image.
        this.sqlRunnerChain.executeResult().forEach(image -> ContextHolder.getExecuteContext().addRecordImage(image));
    }

    /**
     * Creates and validates one runner per parsed statement, returning the pairs in statement order.
     *
     * @param types SQL types, read by the same index as the parsed statements
     * @return list of (statement, runner) tuples
     * @throws SQLException declared so that anything the runner calls may raise can propagate
     *                      exactly as it would from {@link #beforeExecute()}
     */
    private List<TupleTwo<net.sf.jsqlparser.statement.Statement, SqlRunner>> pairStatementsWithRunners(List<SqlType> types) throws SQLException {
        List<net.sf.jsqlparser.statement.Statement> parsed = this.getSqlMeta().getStatements();
        List<TupleTwo<net.sf.jsqlparser.statement.Statement, SqlRunner>> pairs = new ArrayList<>();

        for (int idx = 0; idx < parsed.size(); idx++) {
            SqlType type = types.get(idx);
            SqlRunner runner = SqlRunnerFactoryRegistry.getInstance()
                    .get(type)
                    .getRunner(this.getStatementProxy(), parsed.get(idx));
            runner.validate();

            TupleTwo<net.sf.jsqlparser.statement.Statement, SqlRunner> pair =
                    TupleTwo.<net.sf.jsqlparser.statement.Statement, SqlRunner>builder()
                            .v1(parsed.get(idx))
                            .v2(runner)
                            .build();
            pairs.add(pair);
        }
        return pairs;
    }
}
